// Copyright 2023 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//
// SPDX-License-Identifier: Apache-2.0 OR MIT

//go:build ignore
// +build ignore

package main

// print-jpeg-idct-code.go prints the "func jpeg.decoder.decode_idct" code.
//
// Usage: go run print-jpeg-idct-code.go

import (
	"fmt"
	"os"
	"sort"
	"strings"
)

// main runs main1 and, if it fails, reports the error on stderr and exits
// with status 1. On success it returns normally (exit status 0).
func main() {
	err := main1()
	if err == nil {
		return
	}
	os.Stderr.WriteString(err.Error() + "\n")
	os.Exit(1)
}

func main1() error {
	fmt.Printf("// -------- BEGIN generated by script/print-jpeg-idct-code.go\n")
	fmt.Println()

	fmt.Printf("// p0_298631336 = %s = %10d\n", format8X(p0_298631336), round(p0_298631336))
	fmt.Printf("// p0_390180644 = %s = %10d\n", format8X(p0_390180644), round(p0_390180644))
	fmt.Printf("// p0_509795579 = %s = %10d\n", format8X(p0_509795579), round(p0_509795579))
	fmt.Printf("// p0_541196100 = %s = %10d\n", format8X(p0_541196100), round(p0_541196100))
	fmt.Printf("// p0_601344887 = %s = %10d\n", format8X(p0_601344887), round(p0_601344887))
	fmt.Printf("// p0_765366865 = %s = %10d\n", format8X(p0_765366865), round(p0_765366865))
	fmt.Printf("// p0_785694958 = %s = %10d\n", format8X(p0_785694958), round(p0_785694958))
	fmt.Printf("// p0_899976223 = %s = %10d\n", format8X(p0_899976223), round(p0_899976223))
	fmt.Printf("// p1_175875602 = %s = %10d\n", format8X(p1_175875602), round(p1_175875602))
	fmt.Printf("// p1_306562965 = %s = %10d\n", format8X(p1_306562965), round(p1_306562965))
	fmt.Printf("// p1_501321110 = %s = %10d\n", format8X(p1_501321110), round(p1_501321110))
	fmt.Printf("// p1_847759065 = %s = %10d\n", format8X(p1_847759065), round(p1_847759065))
	fmt.Printf("// p1_961570560 = %s = %10d\n", format8X(p1_961570560), round(p1_961570560))
	fmt.Printf("// p2_053119869 = %s = %10d\n", format8X(p2_053119869), round(p2_053119869))
	fmt.Printf("// p2_562915447 = %s = %10d\n", format8X(p2_562915447), round(p2_562915447))
	fmt.Printf("// p3_072711026 = %s = %10d\n", format8X(p3_072711026), round(p3_072711026))
	fmt.Printf("//\n")
	fmt.Printf("// m0_390180644 = %s = %10d\n", format8X(m0_390180644), round(m0_390180644))
	fmt.Printf("// m0_509795579 = %s = %10d\n", format8X(m0_509795579), round(m0_509795579))
	fmt.Printf("// m0_601344887 = %s = %10d\n", format8X(m0_601344887), round(m0_601344887))
	fmt.Printf("// m0_785694958 = %s = %10d\n", format8X(m0_785694958), round(m0_785694958))
	fmt.Printf("// m0_899976223 = %s = %10d\n", format8X(m0_899976223), round(m0_899976223))
	fmt.Printf("// m1_306562965 = %s = %10d\n", format8X(m1_306562965), round(m1_306562965))
	fmt.Printf("// m1_961570560 = %s = %10d\n", format8X(m1_961570560), round(m1_961570560))
	fmt.Printf("// m2_562915447 = %s = %10d\n", format8X(m2_562915447), round(m2_562915447))
	fmt.Println()

	for x := 0; x < 8; x++ {
		fmt.Println(strings.TrimSpace(replace(pass0, addConsts(map[string]string{
			"$colX$":     fmt.Sprint(x),
			"$row0colX$": fmt.Sprintf("0x%02X", (8*0)|x),
			"$row1colX$": fmt.Sprintf("0x%02X", (8*1)|x),
			"$row2colX$": fmt.Sprintf("0x%02X", (8*2)|x),
			"$row3colX$": fmt.Sprintf("0x%02X", (8*3)|x),
			"$row4colX$": fmt.Sprintf("0x%02X", (8*4)|x),
			"$row5colX$": fmt.Sprintf("0x%02X", (8*5)|x),
			"$row6colX$": fmt.Sprintf("0x%02X", (8*6)|x),
			"$row7colX$": fmt.Sprintf("0x%02X", (8*7)|x),
		}))))
		fmt.Println()
	}

	for y := 0; y < 8; y++ {
		fmt.Println(strings.TrimSpace(replace(pass1, addConsts(map[string]string{
			"$rowY$":         fmt.Sprint(y),
			"$rowYcol0$":     fmt.Sprintf("0x%02X", (8*y)|0),
			"$rowYcol1$":     fmt.Sprintf("0x%02X", (8*y)|1),
			"$rowYcol2$":     fmt.Sprintf("0x%02X", (8*y)|2),
			"$rowYcol3$":     fmt.Sprintf("0x%02X", (8*y)|3),
			"$rowYcol4$":     fmt.Sprintf("0x%02X", (8*y)|4),
			"$rowYcol5$":     fmt.Sprintf("0x%02X", (8*y)|5),
			"$rowYcol6$":     fmt.Sprintf("0x%02X", (8*y)|6),
			"$rowYcol7$":     fmt.Sprintf("0x%02X", (8*y)|7),
			"$bounds_check$": boundsCheck(y == 7),
			"$advance$":      advance(y == 7),
		}))))
		fmt.Println()
	}

	fmt.Printf("// -------- END   generated by script/print-jpeg-idct-code.go\n")
	return nil
}

// replace returns s with every occurrence of each key of m replaced by the
// corresponding value.
//
// The substitution is a single left-to-right pass (strings.Replacer), so
// replacement text is never re-scanned for further keys. Where several keys
// match at the same position, the longest key wins; equal-length ties are
// broken lexicographically.
//
// The keys are sorted first because Go randomizes map iteration order.
// Applying ReplaceAll in map order would make overlapping keys, or values
// that contain keys, produce different output on different runs.
// Sorting keeps the generated code reproducible.
func replace(s string, m map[string]string) string {
	if len(m) == 0 {
		return s
	}
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(i, j int) bool {
		if len(keys[i]) != len(keys[j]) {
			return len(keys[i]) > len(keys[j])
		}
		return keys[i] < keys[j]
	})
	// strings.Replacer gives earlier old/new pairs priority when several
	// match at the same position, so the sorted order above decides ties.
	oldnew := make([]string, 0, 2*len(keys))
	for _, k := range keys {
		oldnew = append(oldnew, k, m[k])
	}
	return strings.NewReplacer(oldnew...).Replace(s)
}

func addConsts(m map[string]string) map[string]string {
	m["$p0_298631336$"] = format8X(p0_298631336)
	m["$p0_509795579$"] = format8X(p0_509795579)
	m["$p0_541196100$"] = format8X(p0_541196100)
	m["$p0_601344887$"] = format8X(p0_601344887)
	m["$p0_765366865$"] = format8X(p0_765366865)
	m["$p0_785694958$"] = format8X(p0_785694958)
	m["$p1_175875602$"] = format8X(p1_175875602)
	m["$p1_306562965$"] = format8X(p1_306562965)
	m["$p1_501321110$"] = format8X(p1_501321110)
	m["$p1_847759065$"] = format8X(p1_847759065)
	m["$p2_053119869$"] = format8X(p2_053119869)
	m["$p3_072711026$"] = format8X(p3_072711026)

	m["$m0_390180644$"] = format8X(m0_390180644)
	m["$m0_509795579$"] = format8X(m0_509795579)
	m["$m0_601344887$"] = format8X(m0_601344887)
	m["$m0_785694958$"] = format8X(m0_785694958)
	m["$m0_899976223$"] = format8X(m0_899976223)
	m["$m1_306562965$"] = format8X(m1_306562965)
	m["$m1_961570560$"] = format8X(m1_961570560)
	m["$m2_562915447$"] = format8X(m2_562915447)

	return m
}

// format8X renders round(x), the 13-bit fixed-point form of x, as a
// hexadecimal literal for the generated code. The literal has eight
// upper-case hex digits, with an underscore between the high and low 16-bit
// halves. For example, x == 1 gives "0x0000_2000".
func format8X(x float64) string {
	v := round(x)
	return fmt.Sprintf("0x%04X_%04X", v>>16, v&0xFFFF)
}

// round converts x to 13-bit fixed point, i.e. round-half-up of x * 8192, as
// a uint32. The mN_... constants are defined as 2^19 - p, so they come out
// as 2^32 - round(p * 8192): the wrapping (two's complement) encoding of the
// negative multiplier.
//
// Hard-code some rounding errors to match libjpeg-turbo exactly.
//
// For example, p1_306562965 is defined precisely as (sqrt2 * (+cos2pi16)).
// Scaling by 8192 gives 10703.36380826727651620162087671551716... which
// rounds down (as a fixed precision approximation) to 10703 = 0x29CF.
//
// Conceptually, m1_306562965 is the negative of that 0x29CF, but it
// appears in libjpeg-turbo's algorithm as the negative of (p1_847759065 -
// p0_541196100). Those two constants (after scaling by 8192 and rounding
// *separately*) fold as (15137 - 4433) = 0x29D0, differing by 1.
//
// The other adjustments follow the same pattern. Each constant is a
// difference of two separately rounded constants:
//   - p0_509795579 = p3_072711026 - p2_562915447: (25172 - 20995) = 4177,
//     whereas rounding 4176.24... directly gives 4176.
//   - p0_785694958 = p1_175875602 - p0_390180644: (9633 - 3196) = 6437,
//     whereas rounding 6436.41... directly gives 6436.
//   - m0_601344887 = p0_298631336 - p0_899976223: (2446 - 7373) = -4927,
//     whereas rounding -4926.21... directly gives -4926.
//
// The switch matches on exact float64 equality, so callers must pass the
// named constant itself, not a recomputed value.
func round(x float64) uint32 {
	fixed := uint32(0.5 + (x * (1 << 13)))
	switch x {
	case p0_509795579, p0_785694958:
		fixed++
	case m0_601344887, m1_306562965:
		fixed-- // uint32 arithmetic wraps, as in the original int32(-1) add.
	}
	return fixed
}

// boundsCheck returns the Wuffs statements that guard the eight
// args.dst_buffer writes for one row of the second pass.
//
// For the final row (final is true), it only requires 8 elements of
// dst_buffer. For every other row, it requires a whole args.dst_stride, so
// that the subsequent advance (slicing dst_buffer at dst_stride) is in
// bounds. It then asserts the 8-element bound via dst_stride.
// NOTE(review): that assertion presumably relies on 8 <= args.dst_stride
// being established elsewhere — confirm in the calling Wuffs code.
func boundsCheck(final bool) string {
	if !final {
		return strings.Join([]string{
			"if args.dst_stride > args.dst_buffer.length() {",
			"    return nothing",
			"}",
			`assert 8 <= args.dst_buffer.length() via "a <= b: a <= c; c <= b"(c: args.dst_stride)`,
		}, "\n")
	}
	return strings.Join([]string{
		"if 8 > args.dst_buffer.length() {",
		"    return nothing",
		"}",
	}, "\n")
}

// advance returns the Wuffs statement that moves args.dst_buffer forward by
// args.dst_stride after a row of the second pass has been written. The
// final row needs no advance, so for it the result is the empty string.
func advance(final bool) string {
	next := ""
	if !final {
		next = "args.dst_buffer = args.dst_buffer[args.dst_stride ..]"
	}
	return next
}

// High-precision inputs for deriving the IDCT multipliers below. These
// untyped constants carry far more digits than a float64 can hold. Go
// evaluates untyped constant expressions at high precision (at least 256
// bits of mantissa), so each derived multiplier is rounded only once, when
// it is converted to float64 (e.g. when passed to round or format8X).
const (
	// sqrt2 ≈ √2.
	sqrt2 = 1.4142135623730950488016887242096980785696718753769480731766797379

	// cosNpi16 ≈ cos(N * (pi / 16)).
	cos1pi16 = 0.9807852804032304491261822361342390369739337308933360950029160885
	cos2pi16 = 0.9238795325112867561281831893967882868224166258636424861150977312
	cos3pi16 = 0.8314696123025452370787883776179057567385608119872499634461245902
	cos5pi16 = 0.5555702330196022247428308139485328743749371907548040459241535282
	cos6pi16 = 0.3826834323650897717284599840303988667613445624856270414338006356
	cos7pi16 = 0.1950903220161282678482848684770222409276916177519548077545020894
)

// Real-valued IDCT multipliers. Each name spells out the value to nine
// decimal places: pA_BBBBBBBBB ≈ +A.BBBBBBBBB, and mA_BBBBBBBBB stands for
// -A.BBBBBBBBB. Every p constant is sqrt2 times a signed sum of the cosines
// above. round and format8X turn these into 13-bit fixed point (times 8192).
//
// The m constants are written as 2^(32-13) - p rather than -p. After scaling
// by 2^13 they become 2^32 - p*8192, the uint32 (two's complement) encoding
// of -p*8192. The generated code uses that encoding with wrapping ~mod*
// arithmetic.
const (
	// A single cosine.
	p0_541196100 = sqrt2 * (+cos6pi16)
	p0_785694958 = sqrt2 * (+cos5pi16)
	p1_175875602 = sqrt2 * (+cos3pi16)
	p1_306562965 = sqrt2 * (+cos2pi16)

	// Sums and differences of two cosines.
	p0_601344887 = sqrt2 * (+cos1pi16 - cos5pi16)

	p0_765366865 = sqrt2 * (+cos2pi16 - cos6pi16)
	p1_847759065 = sqrt2 * (+cos2pi16 + cos6pi16)

	p0_390180644 = sqrt2 * (+cos3pi16 - cos5pi16)
	p0_899976223 = sqrt2 * (+cos3pi16 - cos7pi16)
	p1_961570560 = sqrt2 * (+cos3pi16 + cos5pi16)
	p2_562915447 = sqrt2 * (+cos3pi16 + cos1pi16)

	p0_509795579 = sqrt2 * (+cos5pi16 - cos7pi16)

	// Signed sums of all four odd cosines.
	p0_298631336 = sqrt2 * (-cos1pi16 + cos3pi16 + cos5pi16 - cos7pi16)
	p1_501321110 = sqrt2 * (+cos1pi16 + cos3pi16 - cos5pi16 - cos7pi16)
	p2_053119869 = sqrt2 * (+cos1pi16 + cos3pi16 - cos5pi16 + cos7pi16)
	p3_072711026 = sqrt2 * (+cos1pi16 + cos3pi16 + cos5pi16 - cos7pi16)

	// Negated multipliers, pre-wrapped for uint32 arithmetic (see above).
	m0_390180644 = (1 << (32 - 13)) - p0_390180644
	m0_509795579 = (1 << (32 - 13)) - p0_509795579
	m0_601344887 = (1 << (32 - 13)) - p0_601344887
	m0_785694958 = (1 << (32 - 13)) - p0_785694958
	m0_899976223 = (1 << (32 - 13)) - p0_899976223
	m1_306562965 = (1 << (32 - 13)) - p1_306562965
	m1_961570560 = (1 << (32 - 13)) - p1_961570560
	m2_562915447 = (1 << (32 - 13)) - p2_562915447
)

// pass0 is the Wuffs template for the first (column) pass of the 8x8 IDCT.
// main1 expands it once per column X with these substitutions:
//   - $colX$ becomes X.
//   - Each $rowNcolX$ becomes the hex index 8*N + X into the 64-element
//     block.
//   - Every $p...$ or $m...$ constant becomes its fixed-point value from
//     addConsts.
//
// The generated code dequantizes each coefficient: it multiplies
// mcu_blocks[0] by the matching quant_tables[args.q] entry. It then
// transforms the column in 13-bit fixed point and stores the result in
// intermediate, descaled by 11 bits with rounding (1<<10 is added before the
// shift). If the column's AC terms are all zero, every output is the
// dequantized DC term shifted left by 2, which is the net (<<13 then >>11)
// scale of the full path.
const pass0 = `// ==== First pass, column $colX$.

if (0 == (
        this.mcu_blocks[0][$row1colX$] |
        this.mcu_blocks[0][$row2colX$] |
        this.mcu_blocks[0][$row3colX$] |
        this.mcu_blocks[0][$row4colX$] |
        this.mcu_blocks[0][$row5colX$] |
        this.mcu_blocks[0][$row6colX$] |
        this.mcu_blocks[0][$row7colX$])) {
// Fast path when the 1-dimensional AC terms are all zero.

intermediate[$row0colX$] =
        (this.util.sign_extend_convert_u16_u32(a: this.mcu_blocks[0][$row0colX$]) ~mod*
        (this.quant_tables[args.q][$row0colX$] as base.u32)) ~mod<< 2
intermediate[$row1colX$] = intermediate[$row0colX$]
intermediate[$row2colX$] = intermediate[$row0colX$]
intermediate[$row3colX$] = intermediate[$row0colX$]
intermediate[$row4colX$] = intermediate[$row0colX$]
intermediate[$row5colX$] = intermediate[$row0colX$]
intermediate[$row6colX$] = intermediate[$row0colX$]
intermediate[$row7colX$] = intermediate[$row0colX$]

} else {
// Even rows.

bq2 = this.util.sign_extend_convert_u16_u32(a: this.mcu_blocks[0][$row2colX$]) ~mod* (this.quant_tables[args.q][$row2colX$] as base.u32)
bq6 = this.util.sign_extend_convert_u16_u32(a: this.mcu_blocks[0][$row6colX$]) ~mod* (this.quant_tables[args.q][$row6colX$] as base.u32)

// This code...
ca = (bq2 ~mod+ bq6) ~mod* $p0_541196100$
cb2 = ca ~mod+ (bq2 ~mod* $p0_765366865$)
cb6 = ca ~mod- (bq6 ~mod* $p1_847759065$)
// ...is equivalent to this more-SIMD-like code.
//
// cb2 = (bq2 ~mod* $p1_306562965$) ~mod+ (bq6 ~mod* $p0_541196100$)
// cb6 = (bq2 ~mod* $p0_541196100$) ~mod+ (bq6 ~mod* $m1_306562965$)

bq0 = this.util.sign_extend_convert_u16_u32(a: this.mcu_blocks[0][$row0colX$]) ~mod* (this.quant_tables[args.q][$row0colX$] as base.u32)
bq4 = this.util.sign_extend_convert_u16_u32(a: this.mcu_blocks[0][$row4colX$]) ~mod* (this.quant_tables[args.q][$row4colX$] as base.u32)

ccp = (bq0 ~mod+ bq4) ~mod<< 13
ccm = (bq0 ~mod- bq4) ~mod<< 13

cd0 = ccp ~mod+ cb2
cd1 = ccm ~mod+ cb6
cd2 = ccm ~mod- cb6
cd3 = ccp ~mod- cb2

// Odd rows.

bq1 = this.util.sign_extend_convert_u16_u32(a: this.mcu_blocks[0][$row1colX$]) ~mod* (this.quant_tables[args.q][$row1colX$] as base.u32)
bq3 = this.util.sign_extend_convert_u16_u32(a: this.mcu_blocks[0][$row3colX$]) ~mod* (this.quant_tables[args.q][$row3colX$] as base.u32)
bq5 = this.util.sign_extend_convert_u16_u32(a: this.mcu_blocks[0][$row5colX$]) ~mod* (this.quant_tables[args.q][$row5colX$] as base.u32)
bq7 = this.util.sign_extend_convert_u16_u32(a: this.mcu_blocks[0][$row7colX$]) ~mod* (this.quant_tables[args.q][$row7colX$] as base.u32)

ci51 = bq5 ~mod+ bq1
ci53 = bq5 ~mod+ bq3
ci71 = bq7 ~mod+ bq1
ci73 = bq7 ~mod+ bq3

// This code...
cj = (ci73 ~mod+ ci51) ~mod* $p1_175875602$
ck1 = bq1 ~mod* $p1_501321110$
ck3 = bq3 ~mod* $p3_072711026$
ck5 = bq5 ~mod* $p2_053119869$
ck7 = bq7 ~mod* $p0_298631336$
ci51 ~mod*= $m0_390180644$
ci53 ~mod*= $m2_562915447$
ci71 ~mod*= $m0_899976223$
ci73 ~mod*= $m1_961570560$
cl51 = ci51 ~mod+ cj
cl73 = ci73 ~mod+ cj
ck1 ~mod+= ci71 ~mod+ cl51
ck3 ~mod+= ci53 ~mod+ cl73
ck5 ~mod+= ci53 ~mod+ cl51
ck7 ~mod+= ci71 ~mod+ cl73
// ...is equivalent to this more-SIMD-like code.
//
// cl73 = (ci73 ~mod* $m0_785694958$) ~mod+ (ci51 ~mod* $p1_175875602$)
// cl51 = (ci73 ~mod* $p1_175875602$) ~mod+ (ci51 ~mod* $p0_785694958$)
// ck1 = cl51 ~mod+ ((bq1 ~mod* $p0_601344887$) ~mod+ (bq7 ~mod* $m0_899976223$))
// ck3 = cl73 ~mod+ ((bq3 ~mod* $p0_509795579$) ~mod+ (bq5 ~mod* $m2_562915447$))
// ck5 = cl51 ~mod+ ((bq3 ~mod* $m2_562915447$) ~mod+ (bq5 ~mod* $m0_509795579$))
// ck7 = cl73 ~mod+ ((bq1 ~mod* $m0_899976223$) ~mod+ (bq7 ~mod* $m0_601344887$))

// Combine rows.

intermediate[$row0colX$] = this.util.sign_extend_rshift_u32(a: (cd0 ~mod+ ck1) ~mod+ (1 << 10), n: 11)
intermediate[$row7colX$] = this.util.sign_extend_rshift_u32(a: (cd0 ~mod- ck1) ~mod+ (1 << 10), n: 11)
intermediate[$row1colX$] = this.util.sign_extend_rshift_u32(a: (cd1 ~mod+ ck3) ~mod+ (1 << 10), n: 11)
intermediate[$row6colX$] = this.util.sign_extend_rshift_u32(a: (cd1 ~mod- ck3) ~mod+ (1 << 10), n: 11)
intermediate[$row2colX$] = this.util.sign_extend_rshift_u32(a: (cd2 ~mod+ ck5) ~mod+ (1 << 10), n: 11)
intermediate[$row5colX$] = this.util.sign_extend_rshift_u32(a: (cd2 ~mod- ck5) ~mod+ (1 << 10), n: 11)
intermediate[$row3colX$] = this.util.sign_extend_rshift_u32(a: (cd3 ~mod+ ck7) ~mod+ (1 << 10), n: 11)
intermediate[$row4colX$] = this.util.sign_extend_rshift_u32(a: (cd3 ~mod- ck7) ~mod+ (1 << 10), n: 11)
}
`

// pass1 is the Wuffs template for the second (row) pass of the 8x8 IDCT.
// main1 expands it once per row Y with these substitutions:
//   - $rowY$ becomes Y.
//   - Each $rowYcolN$ becomes the hex index 8*Y + N into intermediate.
//   - $bounds_check$ and $advance$ become the output of boundsCheck and
//     advance.
//   - Every $p...$ or $m...$ constant becomes its fixed-point value from
//     addConsts.
//
// Each row is transformed with the same multipliers as pass0. The results
// are descaled by 18 bits with rounding (1<<17 is added before the shift),
// masked with 1023, and looked up in BIAS_AND_CLAMP (a table defined
// elsewhere) to give the eight args.dst_buffer outputs. If the row's AC
// terms are all zero, every output is the DC term descaled by 5 bits, which
// is the net (<<13 then >>18) scale of the full path.
const pass1 = `// ==== Second pass, row $rowY$.

if (0 == (
        intermediate[$rowYcol1$] |
        intermediate[$rowYcol2$] |
        intermediate[$rowYcol3$] |
        intermediate[$rowYcol4$] |
        intermediate[$rowYcol5$] |
        intermediate[$rowYcol6$] |
        intermediate[$rowYcol7$])) {
// Fast path when the 1-dimensional AC terms are all zero.

$bounds_check$

args.dst_buffer[0] = BIAS_AND_CLAMP[((intermediate[$rowYcol0$] ~mod+ (1 << 4)) >> 5) & 1023]
args.dst_buffer[1] = args.dst_buffer[0]
args.dst_buffer[2] = args.dst_buffer[0]
args.dst_buffer[3] = args.dst_buffer[0]
args.dst_buffer[4] = args.dst_buffer[0]
args.dst_buffer[5] = args.dst_buffer[0]
args.dst_buffer[6] = args.dst_buffer[0]
args.dst_buffer[7] = args.dst_buffer[0]

$advance$

} else {
// Even columns.

in2 = intermediate[$rowYcol2$]
in6 = intermediate[$rowYcol6$]

// This code...
ra = (in2 ~mod+ in6) ~mod* $p0_541196100$
rb2 = ra ~mod+ (in2 ~mod* $p0_765366865$)
rb6 = ra ~mod- (in6 ~mod* $p1_847759065$)
// ...is equivalent to this more-SIMD-like code.
//
// rb2 = (in2 ~mod* $p1_306562965$) ~mod+ (in6 ~mod* $p0_541196100$)
// rb6 = (in2 ~mod* $p0_541196100$) ~mod+ (in6 ~mod* $m1_306562965$)

in0 = intermediate[$rowYcol0$]
in4 = intermediate[$rowYcol4$]

rcp = (in0 ~mod+ in4) ~mod<< 13
rcm = (in0 ~mod- in4) ~mod<< 13

rd0 = rcp ~mod+ rb2
rd1 = rcm ~mod+ rb6
rd2 = rcm ~mod- rb6
rd3 = rcp ~mod- rb2

// Odd columns.

in1 = intermediate[$rowYcol1$]
in3 = intermediate[$rowYcol3$]
in5 = intermediate[$rowYcol5$]
in7 = intermediate[$rowYcol7$]

ri51 = in5 ~mod+ in1
ri53 = in5 ~mod+ in3
ri71 = in7 ~mod+ in1
ri73 = in7 ~mod+ in3

// This code...
rj = (ri73 ~mod+ ri51) ~mod* $p1_175875602$
rk1 = in1 ~mod* $p1_501321110$
rk3 = in3 ~mod* $p3_072711026$
rk5 = in5 ~mod* $p2_053119869$
rk7 = in7 ~mod* $p0_298631336$
ri51 ~mod*= $m0_390180644$
ri53 ~mod*= $m2_562915447$
ri71 ~mod*= $m0_899976223$
ri73 ~mod*= $m1_961570560$
rl51 = ri51 ~mod+ rj
rl73 = ri73 ~mod+ rj
rk1 ~mod+= ri71 ~mod+ rl51
rk3 ~mod+= ri53 ~mod+ rl73
rk5 ~mod+= ri53 ~mod+ rl51
rk7 ~mod+= ri71 ~mod+ rl73
// ...is equivalent to this more-SIMD-like code.
//
// rl73 = (ri73 ~mod* $m0_785694958$) ~mod+ (ri51 ~mod* $p1_175875602$)
// rl51 = (ri73 ~mod* $p1_175875602$) ~mod+ (ri51 ~mod* $p0_785694958$)
// rk1 = rl51 ~mod+ ((in1 ~mod* $p0_601344887$) ~mod+ (in7 ~mod* $m0_899976223$))
// rk3 = rl73 ~mod+ ((in3 ~mod* $p0_509795579$) ~mod+ (in5 ~mod* $m2_562915447$))
// rk5 = rl51 ~mod+ ((in3 ~mod* $m2_562915447$) ~mod+ (in5 ~mod* $m0_509795579$))
// rk7 = rl73 ~mod+ ((in1 ~mod* $m0_899976223$) ~mod+ (in7 ~mod* $m0_601344887$))

// Combine columns.

$bounds_check$

args.dst_buffer[0] = BIAS_AND_CLAMP[(((rd0 ~mod+ rk1) ~mod+ (1 << 17)) >> 18) & 1023]
args.dst_buffer[7] = BIAS_AND_CLAMP[(((rd0 ~mod- rk1) ~mod+ (1 << 17)) >> 18) & 1023]
args.dst_buffer[1] = BIAS_AND_CLAMP[(((rd1 ~mod+ rk3) ~mod+ (1 << 17)) >> 18) & 1023]
args.dst_buffer[6] = BIAS_AND_CLAMP[(((rd1 ~mod- rk3) ~mod+ (1 << 17)) >> 18) & 1023]
args.dst_buffer[2] = BIAS_AND_CLAMP[(((rd2 ~mod+ rk5) ~mod+ (1 << 17)) >> 18) & 1023]
args.dst_buffer[5] = BIAS_AND_CLAMP[(((rd2 ~mod- rk5) ~mod+ (1 << 17)) >> 18) & 1023]
args.dst_buffer[3] = BIAS_AND_CLAMP[(((rd3 ~mod+ rk7) ~mod+ (1 << 17)) >> 18) & 1023]
args.dst_buffer[4] = BIAS_AND_CLAMP[(((rd3 ~mod- rk7) ~mod+ (1 << 17)) >> 18) & 1023]

$advance$
}
`
